{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data folder is C:\\Users\\lfranceschi\\DATASETS\n",
      "Experiment save directory is  C:\\Users\\lfranceschi\\EXPERIMENTS\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import experiment_manager as em\n",
    "import far_ho as far"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Checking online hyperparameter optimization for `ReverseHG`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _test(method):\n",
    "    \"\"\"Build a toy bilevel problem in a fresh graph and interactive session.\n",
    "\n",
    "    Args:\n",
    "        method: zero-argument callable (e.g. `far.ReverseHg`); its result is\n",
    "            passed to `far.HyperOptimizer` as `hypergradient`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(ss, farho, cost, oo)`: the new `tf.InteractiveSession`, the\n",
    "        `far.HyperOptimizer`, the inner objective and the outer objective.\n",
    "    \"\"\"\n",
    "    # An InteractiveSession installs itself as the default session, so the one\n",
    "    # left over from an earlier call is reachable here. Close it *before*\n",
    "    # resetting the graph; a function-local `ss` is never bound at this point,\n",
    "    # so it cannot be used to reach the old session.\n",
    "    stale_session = tf.get_default_session()\n",
    "    if stale_session is not None:\n",
    "        stale_session.close()\n",
    "    tf.reset_default_graph()\n",
    "    ss = tf.InteractiveSession()\n",
    "\n",
    "    v1 = tf.Variable([[1., -1., 1.], [2., 10., 2.]])\n",
    "    v2 = tf.Variable([-1., 1., 0.2])\n",
    "    lmbd = far.get_hyperparameter('lambda',\n",
    "                                  initializer=tf.ones_initializer,\n",
    "                                  shape=v2.get_shape(), scalar=True)\n",
    "\n",
    "    # Inner objective; `lmbd` weights the v2 term.\n",
    "    cost = tf.reduce_mean(v1**2) + tf.reduce_sum(lmbd*v2**2)\n",
    "    # alternative: cost = tf.reduce_sum(lmbd*v2**2)\n",
    "\n",
    "    _et = far.get_hyperparameter('eta', 0.1)\n",
    "    io_optim = far.MomentumOptimizer(_et, far.get_hyperparameter('mu', .5))\n",
    "    # alternative: io_optim = far.GradientDescentOptimizer(_et)\n",
    "\n",
    "    # Return value is unused, but the call is kept because it may add ops to\n",
    "    # the graph -- NOTE(review): confirm whether it is still needed.\n",
    "    io_optim.minimize(cost)\n",
    "\n",
    "    # Outer objective and its optimizer.\n",
    "    oo = tf.reduce_mean(v1*v2)\n",
    "    # alternative: oo = tf.reduce_mean(v2)\n",
    "    optim_oo = tf.train.AdamOptimizer(.01)\n",
    "    farho = far.HyperOptimizer(hypergradient=method())\n",
    "    farho.minimize(oo, optim_oo, cost, io_optim)\n",
    "    return ss, farho, cost, oo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### REVERSE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-0.018393531, 0.055180617, 0.0036787055, 0.53872281, -0.076542258]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ss, farho, cost, oo = _test(far.ReverseHg)\n",
    "\n",
    "tf.global_variables_initializer().run()\n",
    "\n",
    "# execution with the momentum inner optimizer (looks ok)\n",
    "farho.run(10)\n",
    "\n",
    "hypergrads_rev = ss.run(far.utils.hypergradients())\n",
    "hypergrads_rev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(farho.hypergradient._history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[20.540001,\n",
       " 18.592821,\n",
       " 16.177711,\n",
       " 13.956007,\n",
       " 12.062778,\n",
       " 10.446499,\n",
       " 9.0413418,\n",
       " 7.8095841,\n",
       " 6.7327967,\n",
       " 5.7977762]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "farho.hypergradient.inner_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-0.018393531, 0.055180617, 0.0036787055, 0.53872281, -0.076542258]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ss, farho, cost, oo = _test(far.ReverseHg)\n",
    "\n",
    "tf.global_variables_initializer().run()\n",
    "\n",
    "# execution with the momentum inner optimizer (looks ok)\n",
    "farho.run(10, online=True, _skip_hyper_ts=True)\n",
    "\n",
    "hypergrads_online = ss.run(far.utils.hypergradients())\n",
    "hypergrads_online"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(np.array(hypergrads_online) - np.array(hypergrads_rev)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### OK... SOUNDS GOOD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 0.51925904, -0.51925904,  0.51925904],\n",
       "        [ 1.03851807,  5.19259167,  1.03851807]], dtype=float32),\n",
       " array([ 0.03254529, -0.03254529, -0.00650906], dtype=float32)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ss.run(tf.trainable_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 0.24525888, -0.24525888,  0.24525888],\n",
       "        [ 0.49051777,  2.45258904,  0.49051777]], dtype=float32),\n",
       " array([ -2.79949629e-04,   2.79949629e-04,   5.59900072e-05], dtype=float32)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''With online optimization the parameters are not reinitialized. This corresponds to performing\n",
    "a warm restart for solving the inner optimization problem.\n",
    "'''\n",
    "farho.run(10, online=True, _skip_hyper_ts=True)\n",
    "\n",
    "ss.run(tf.trainable_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 0.51925904, -0.51925904,  0.51925904],\n",
       "        [ 1.03851807,  5.19259167,  1.03851807]], dtype=float32),\n",
       " array([ 0.03254529, -0.03254529, -0.00650906], dtype=float32)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# check that the parameters obtained with batch hyperparameter optimization are indeed different...\n",
    "farho.run(10, online=False,  _skip_hyper_ts=True)\n",
    "\n",
    "ss.run(tf.trainable_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
